import numpy as np
import math
from scipy.spatial.transform import Rotation
from scipy.interpolate import splprep, splev

def skew(x):
    """Return the 3x3 skew-symmetric (cross-product) matrix of a 3-vector.

    For 3-vectors x and v, ``skew(x) @ v`` equals ``np.cross(x, v)``.
    The result is a float64 array with zeros on the diagonal.
    """
    m = np.zeros((3, 3))
    # Fill the off-diagonal entries row by row; the right-hand tuples are
    # evaluated before any assignment happens.
    m[0, 1], m[0, 2] = -x[2], x[1]
    m[1, 0], m[1, 2] = x[2], -x[0]
    m[2, 0], m[2, 1] = -x[1], x[0]
    return m

def so3_to_SO3(so3):
    """Convert a rotation vector (axis * angle) to a 3x3 rotation matrix.

    Uses Rodrigues' formula
        R = cos(theta) I + sin(theta) [a]_x + (1 - cos(theta)) a a^T
    with theta = |so3| and unit axis a = so3 / theta.

    Args:
        so3: length-3 rotation vector; its norm is the rotation angle in radians.

    Returns:
        3x3 rotation matrix. A zero (or numerically zero) vector returns the identity.
    """
    so3 = np.asarray(so3, dtype=float)
    theta = np.linalg.norm(so3)
    # For theta == 0 the axis so3/theta is 0/0 (NaN); below this tolerance
    # R equals I to within the tolerance, so return it directly.
    if theta < 1e-12:
        return np.eye(3)
    axis = so3 / theta
    cos_t = np.cos(theta)
    return (cos_t * np.eye(3) + np.sin(theta) * skew(axis)
            + (1.0 - cos_t) * np.outer(axis, axis))


def SO3_to_so3_hat(R):
    """Return the matrix logarithm of a rotation matrix as a skew-symmetric matrix.

    The result H equals [w]_x, where w is the rotation vector (axis * angle,
    angle in [0, pi]) of R, so that (H[2, 1], H[0, 2], H[1, 0]) == w.

    Args:
        R: 3x3 rotation matrix.

    Returns:
        3x3 skew-symmetric matrix; all zeros when the rotation angle is below 1e-6.
    """
    R = np.asarray(R, dtype=float)
    # Round-off can push (trace - 1) / 2 just outside [-1, 1], where arccos is NaN.
    cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-6:
        return np.zeros((3, 3))
    if np.pi - theta < 1e-6:
        # Near pi, sin(theta) -> 0 and R - R^T carries almost no axis
        # information (it is exactly zero at pi), so the closed form below
        # breaks down. Use scipy's conversion, which handles this case.
        w = Rotation.from_matrix(R).as_rotvec()
        return np.array([[0.0, -w[2], w[1]],
                         [w[2], 0.0, -w[0]],
                         [-w[1], w[0], 0.0]])
    return (theta / (2.0 * np.sin(theta))) * (R - R.T)

def SO3_to_so3(R):
    """Return the rotation vector (axis * angle) of a 3x3 rotation matrix.

    The vector is read off the skew-symmetric matrix H = SO3_to_so3_hat(R)
    as (H[2, 1], H[0, 2], H[1, 0]).
    """
    hat = SO3_to_so3_hat(R)
    # Fancy indexing picks H[2, 1], H[0, 2], H[1, 0] into a new 1-D array.
    return hat[[2, 0, 1], [1, 2, 0]]


def generate_S_trajectory(radius, num_points):
    """Generate positions on a closed 3-D curve plus smoothly varying random orientations.

    Positions are sampled at num_points values of t evenly spaced over [0, 2*pi]:
        x = r sin t,  y = r sin t cos t,  z = r cos t.
    Orientations come from a random walk of xyz Euler angles at key frames
    (one key frame per 10 samples, at least 4). The key frames are smoothed
    with smooth_quaternions and truncated so there is exactly one quaternion
    per position.

    Args:
        radius: curve scale r.
        num_points: number of poses to generate (must be >= 1).

    Returns:
        (trans, quats): list of position arrays of shape (3,) and list of
        quaternions as returned by scipy's Rotation.as_quat, both of length num_points.

    Raises:
        ValueError: if num_points < 1.
    """
    if num_points < 1:
        raise ValueError("num_points must be at least 1")

    ts = np.linspace(0, 2 * np.pi, num_points)
    trans = [np.array([radius * np.sin(t),
                       radius * np.sin(t) * np.cos(t),
                       radius * np.cos(t)]) for t in ts]

    upsample = 10
    # A cubic interpolating spline needs at least 4 key frames. Rounding up
    # ensures the smoothed sequence is never shorter than the positions.
    num_keys = max(4, math.ceil(num_points / upsample))

    quats = []
    # Initial orientation, then a biased random walk per axis.
    euler = np.random.uniform(-math.pi / 8, math.pi / 8, size=(3,))
    for _ in range(num_keys):
        euler += np.random.uniform(-2 * math.pi / num_keys / 10, 2 * math.pi / num_keys, size=(3,))
        quats.append(Rotation.from_euler('xyz', euler).as_quat())

    interp_quats = smooth_quaternions(quats, upsample)
    # Exactly one orientation per position.
    return trans, interp_quats[:num_points]

def smooth_quaternions(quat, times):
    """Upsample a sequence of orientations by spline-interpolating their Euler angles.

    Each quaternion is converted to xyz Euler angles. The angles are unwrapped
    so consecutive samples never jump by more than pi; otherwise a wrap at
    +/-pi would make the spline sweep through unrelated orientations. An
    interpolating spline (s=0) is then evaluated at len(quat) * times evenly
    spaced parameters.

    NOTE(review): unwrapping only removes 2*pi jumps. It does not repair the
    discontinuity of the xyz Euler representation when the middle angle passes
    +/-pi/2.

    Args:
        quat: sequence of quaternions in scipy's Rotation.from_quat order.
        times: upsampling factor.

    Returns:
        List of len(quat) * times quaternions (scipy's as_quat order). The
        first and last samples reproduce the first and last inputs.
    """
    quat = list(quat)
    n = len(quat)
    if n == 0:
        return []
    if n == 1:
        # A single orientation cannot be splined; repeat it.
        only = Rotation.from_quat(quat[0]).as_quat()
        return [only.copy() for _ in range(times)]

    euler = np.array([Rotation.from_quat(q).as_euler('xyz') for q in quat])
    euler = np.unwrap(euler, axis=0)

    # splprep requires more points than the spline degree; use cubic when possible.
    k = min(3, n - 1)
    tck, _ = splprep(euler.T, s=0, k=k)

    interp_t = np.linspace(0, 1, n * times)
    interp_euler = np.array(splev(interp_t, tck)).T

    return [Rotation.from_euler('xyz', e).as_quat() for e in interp_euler]


def transform_trajectory(trans, quats, rotation, translation, s):
    """Apply a scaled rotation plus translation (s, R, t) to a trajectory.

    Each position p becomes (s * R) @ p + t. Each orientation R1 (decoded
    from quats) becomes R1 @ R^T. NOTE(review): right-multiplying by R^T
    presumably reflects a convention where the stored rotation is the inverse
    of the pose rotation — confirm.

    Args:
        trans: sequence of positions (3-vectors).
        quats: quaternions indexed in step with trans (scipy from_quat order).
        rotation: scipy Rotation giving R.
        translation: 3-vector t.
        s: scale factor.

    Returns:
        (transformed_trans, rotated_quat) as two lists.
    """
    rot = rotation.as_matrix()
    rot_t = rot.T
    scaled_rot = s * rot

    positions_out, quats_out = [], []
    for idx, point in enumerate(trans):
        orientation = Rotation.from_quat(quats[idx]).as_matrix() @ rot_t
        quats_out.append(Rotation.from_matrix(orientation).as_quat())
        positions_out.append(scaled_rot @ point + translation)

    return positions_out, quats_out




# Similarity transform with an axis-flip matrix S, a random body-frame offset
# dT and optional accumulated drift between the two trajectories.
def Stransform_trajectory(trans, quats, rotation, translation, s, Is_add_drift = True):
    """Map a trajectory through [s, S, R, t] and a random body-frame offset dT.

    Convention: a stored pose is [R, t] with Twc = [R^T, t], and the output
    pose is T2wc = T21w * T1wc * dT21c^-1, where T = [s, S, R, t] and
    dT = [S, dR, dt]. Concretely, for each input pose (R1, t1):

        R2 = S * dR * R1 * R^T * S
        t2 = -S * R * R1^T * dR^T * S * dt + s * S * R * t1 + t

    dR is a random rotation and dt a random translation drawn once per call
    (in the order: angle, dt, axis).

    Args:
        trans: sequence of positions (3-vectors).
        quats: quaternions indexed in step with trans (scipy from_quat order).
        rotation: scipy Rotation giving R.
        translation: 3-vector t.
        s: scale factor.
        Is_add_drift: when True, add a biased Gaussian random walk to the input
            orientations (as xyz Euler angles) and positions before transforming.

    Returns:
        (transformed_trans, rotated_quat) as two lists.

    Side effects: prints dR, dt, S and the root-mean-square of the injected
    drift (per Euler axis, rotation norm, translation norm).
    """
    # TODO: change S here (e.g. S[2, 2] = -1 mirrors the z axis).
    S = np.identity(3, dtype=int)

    # TODO: change dR & dt here.
    dr_angle = np.random.uniform(-2*math.pi, 2*math.pi)
    dt = np.random.uniform(-1, 1, size=(3,))
    dr_vector = np.random.uniform(-6, 6, size=(3,))
    dr_vector /= np.linalg.norm(dr_vector)
    dr = Rotation.from_rotvec(dr_angle * dr_vector).as_matrix()

    print("dr angle: \n", dr_angle)
    print("dR:\n", dr)
    print("dt:\n", dt)

    r = rotation.as_matrix()
    t = translation
    r_t = r.T
    sSr = s*S @ r  # (s*S) @ r, loop invariant

    # Must be float arrays: with plain lists, `+=` with an ndarray extends the
    # list instead of adding element-wise.
    drift_r_euler = np.zeros(3)
    drift_t = np.zeros(3)

    # Running sums of squared drift, for the RMSE report.
    sq_err_r_euler = np.zeros(3)
    sq_err_r = 0.0
    sq_err_t = 0.0

    rotated_quat = []
    transformed_trans = []
    for i, point in enumerate(trans):
        if Is_add_drift:
            # TODO: change the noise model here (biased Gaussian random walk).
            drift_r_euler += np.random.normal(0.0002, 0.001, 3)
            drift_t += np.random.normal(0.0002, 0.001, 3)
            sq_err_r_euler += drift_r_euler ** 2
            sq_err_r += drift_r_euler @ drift_r_euler
            sq_err_t += drift_t @ drift_t

        drift_r = Rotation.from_euler('xyz', drift_r_euler).as_matrix()

        r1 = drift_r @ Rotation.from_quat(quats[i]).as_matrix()
        r2 = S @ dr @ r1 @ r_t @ S
        t1 = point + drift_t
        t2 = -r2.T @ dt + sSr @ t1 + t

        rotated_quat.append(Rotation.from_matrix(r2).as_quat())
        transformed_trans.append(t2)

    # max(..., 1) avoids 0/0 for an empty trajectory (all sums are 0 then).
    count = max(len(transformed_trans), 1)
    rmse_r_euler = np.sqrt(sq_err_r_euler / count)
    rmse_r = np.sqrt(sq_err_r / count)
    rmse_t = np.sqrt(sq_err_t / count)

    print("S:\n", S)
    print("rmse: rx ry rz r t: ", rmse_r_euler[0], rmse_r_euler[1], rmse_r_euler[2], rmse_r, rmse_t)

    return transformed_trans, rotated_quat



def save_tum_trajectory(filename, trans, quats):
    """Write a trajectory as text, one pose per line.

    Each line is "index tx ty tz q0 q1 q2 q3". The sample index goes in the
    timestamp column. Quaternion components are written in the order given;
    scipy's as_quat order is x, y, z, w. Positions and quaternions are
    paired by index, so quats must be at least as long as trans. The file
    is overwritten.
    """
    with open(filename, 'w') as out:
        for idx, p in enumerate(trans):
            q = quats[idx]
            out.write(f"{idx} {p[0]} {p[1]} {p[2]} {q[0]} {q[1]} {q[2]} {q[3]}\n")


if __name__ == "__main__":
    radius = 2
    num_points = 1000
    filename = "S_trajectory.txt"
    # Tt = [sSR t] Ts
    filename_transformed = "S_trajectory_transformed.txt"
    # Adds drift, s, S and a delta-T between the two body frames.
    filename_Stransformed = "S_trajectory_Stransformed.txt"

    # TODO: toggle drift here (True / False).
    Is_add_drift = True

    def _random_similarity(t_half_range):
        """Draw a random (rotation, translation, scale).

        Draw order: angle, translation, axis, scale.
        """
        angle = np.random.uniform(-math.pi/4, math.pi/4)
        offset = np.random.uniform(-t_half_range, t_half_range, size=(3,))
        axis = np.random.uniform(-1, 1, size=(3,))
        axis /= np.linalg.norm(axis)
        return Rotation.from_rotvec(angle * axis), offset, np.random.uniform(0, 1)

    trans, quats = generate_S_trajectory(radius, num_points)
    save_tum_trajectory(filename, trans, quats)

    # Plain similarity transform of the S trajectory.
    rotation, translation, s = _random_similarity(1)
    trans_transformed, quats_rotated = transform_trajectory(trans, quats, rotation, translation, s)
    save_tum_trajectory(filename_transformed, trans_transformed, quats_rotated)

    # Similarity transform of the S trajectory plus S, body-frame offset and drift.
    rotation, translation, s = _random_similarity(5)
    print("R:\n", rotation.as_matrix())
    print("t:\n", translation)
    print("s: ", s)

    # Te * T = dT * Tg
    trans_transformed, quats_rotated = Stransform_trajectory(trans, quats, rotation, translation, s, Is_add_drift)
    save_tum_trajectory(filename_Stransformed, trans_transformed, quats_rotated)


